{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "from collections import defaultdict\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "# Seed both RNGs used later in the notebook (np.random.shuffle / random.choice).\n",
    "np.random.seed(1234)\n",
    "random.seed(1234)\n",
    "\n",
    "# Sets the CUDA_VISIBLE_DEVICES environment variable for this kernel.\n",
    "%env CUDA_VISIBLE_DEVICES=0\n",
    "\n",
    "\n",
    "# Language pair and model checkpoint.\n",
    "src_lang = 'de'\n",
    "dst_lang = 'en'\n",
    "checkpoint_path = 'baseline_checkpoint.pt'\n",
    "\n",
    "# Input data directory and the directory the generated files are written to.\n",
    "data_path = 'data-bin/iwslt14.tokenized.{}-{}/'.format(src_lang, dst_lang)\n",
    "output_path = 'edit_iwslt14.tokenized.{}-{}'.format(src_lang, dst_lang)\n",
    "\n",
    "# Generation settings passed to fairseq-generate; nbest is kept equal to beam.\n",
    "beam = 32\n",
    "nbest = beam\n",
    "max_tokens = 1024\n",
    "\n",
    "# Dataset subsets to process, in this order.\n",
    "keys = ['valid', 'test', 'train']\n",
    "\n",
    "\n",
    "# Make the relative 'fairseq' directory importable (presumably a local checkout).\n",
    "sys.path.append('fairseq')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output directory if needed; refuse to reuse one that already has content.\n",
    "if not os.path.isdir(output_path):\n",
    "    os.makedirs(output_path, exist_ok=True)\n",
    "elif len(os.listdir(output_path)) > 0:\n",
    "    raise ValueError('Output directory {} is not empty'.format(output_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the baseline checkpoint unless it is already present.\n",
    "if not os.path.isfile(checkpoint_path):\n",
    "    # The URL is quoted because it contains `?`, which the shell would otherwise\n",
    "    # treat as a glob character; the output path is quoted in case it has spaces.\n",
    "    ! wget 'https://www.dropbox.com/s/iksezig9qi4g92e/baseline_checkpoint.pt?dl=1' -O '{checkpoint_path}'\n",
    "    # `wget -O` creates the output file before downloading, so a failed download\n",
    "    # leaves an empty file that would make the check above skip the download next time.\n",
    "    if os.path.isfile(checkpoint_path) and os.path.getsize(checkpoint_path) == 0:\n",
    "        os.remove(checkpoint_path)\n",
    "        raise RuntimeError('Failed to download checkpoint to {}'.format(checkpoint_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fairseq.data.dictionary import Dictionary\n",
    "\n",
    "\n",
    "def _load_vocabulary(lang):\n",
    "    \"\"\"Load the dictionary file `dict.<lang>.txt` found under `data_path`.\"\"\"\n",
    "    dict_file = os.path.join(data_path, 'dict.{}.txt'.format(lang))\n",
    "    return Dictionary.load(dict_file)\n",
    "\n",
    "\n",
    "src_voc = _load_vocabulary(src_lang)\n",
    "dst_voc = _load_vocabulary(dst_lang)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_alternatives(key, file):\n",
    "    \"\"\"Run fairseq-generate (beam search) on subset `key` and write its output to `file`.\n",
    "\n",
    "    Uses the notebook-level `data_path`, `checkpoint_path`, `beam`, `nbest` and\n",
    "    `max_tokens` settings. Paths are quoted so that spaces or shell\n",
    "    metacharacters in them cannot break the command or the redirection.\n",
    "    \"\"\"\n",
    "    !fairseq-generate '{data_path}' --path '{checkpoint_path}' \\\n",
    "     --beam {beam} --nbest {nbest} \\\n",
    "     --max-tokens {max_tokens} \\\n",
    "     --gen-subset {key} > '{file}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_edits(key, file):\n",
    "    \"\"\"Run fairseq-generate with sampling on subset `key` and write its output to `file`.\n",
    "\n",
    "    Same command as `generate_alternatives` plus `--sampling --temperature 1.2`.\n",
    "    Uses the notebook-level `data_path`, `checkpoint_path`, `beam`, `nbest` and\n",
    "    `max_tokens` settings. Paths are quoted so that spaces or shell\n",
    "    metacharacters in them cannot break the command or the redirection.\n",
    "    \"\"\"\n",
    "    !fairseq-generate '{data_path}' --path '{checkpoint_path}' \\\n",
    "     --beam {beam} --nbest {nbest} \\\n",
    "     --max-tokens {max_tokens} \\\n",
    "     --sampling --temperature 1.2 \\\n",
    "     --gen-subset {key} > '{file}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _parse_generate_output(path):\n",
    "    \"\"\"Group the tab-separated lines of a fairseq-generate output file by tag.\n",
    "\n",
    "    Returns a dict mapping each tag (the text before the first tab, e.g.\n",
    "    'S-12' or 'H-12') to the list of payloads (the remaining tab-separated\n",
    "    fields) of every line carrying that tag. Lines without a tab are ignored.\n",
    "    \"\"\"\n",
    "    parsed = defaultdict(list)\n",
    "    with open(path) as f_in:\n",
    "        for line in f_in:\n",
    "            tag, *payload = line.split('\\t')\n",
    "            if payload:\n",
    "                parsed[tag].append(payload)\n",
    "    return parsed\n",
    "\n",
    "\n",
    "def _sample_ids(parsed):\n",
    "    \"\"\"Return the set of integer sample ids that have an 'S-<id>' entry in `parsed`.\"\"\"\n",
    "    return {int(tag[2:]) for tag in parsed if tag.startswith('S-') and tag[2:].isdigit()}\n",
    "\n",
    "\n",
    "for key in keys:\n",
    "    # NOTE: the 'alternaatives' spelling is kept as-is so existing file names do not change.\n",
    "    tmp_alternatives_output_file_path = os.path.join(output_path, 'tmp_alternaatives_{}.txt'.format(key))\n",
    "    tmp_edits_output_file_path = os.path.join(output_path, 'tmp_edits_{}.txt'.format(key))\n",
    "    bpe_output_file_path = os.path.join(output_path, 'bpe_{}.txt'.format(key))\n",
    "    \n",
    "    print('Start generating {} alternatives'.format(key))\n",
    "    generate_alternatives(key, tmp_alternatives_output_file_path)\n",
    "    print('Start generating {} edits'.format(key))\n",
    "    generate_edits(key, tmp_edits_output_file_path)\n",
    "\n",
    "    print('Start parsing the alternatives beam search output')\n",
    "    logs = _parse_generate_output(tmp_alternatives_output_file_path)\n",
    "\n",
    "    print('Start parsing the edits beam search output')\n",
    "    edit_logs = _parse_generate_output(tmp_edits_output_file_path)\n",
    "\n",
    "    print('Start edit samples generating')\n",
    "\n",
    "    # fairseq-generate sample ids start at 0 and need not be contiguous, so use\n",
    "    # the ids actually present in both outputs instead of counting up from 1\n",
    "    # (which dropped sample 0 and stopped at the first missing id).\n",
    "    sample_ids = sorted(_sample_ids(logs) & _sample_ids(edit_logs))\n",
    "\n",
    "    with open(bpe_output_file_path, 'w') as f_out:\n",
    "        for i in sample_ids:\n",
    "            source = logs['S-{}'.format(i)][0][0].strip()\n",
    "            hypos = [hypo.strip() for prob, hypo in logs.get('H-{}'.format(i), [])]\n",
    "            edits = [edit.strip() for prob, edit in edit_logs.get('H-{}'.format(i), [])]\n",
    "            if not edits:\n",
    "                # random.choice raises on an empty list; report and skip instead.\n",
    "                print('No edit hypotheses for sample {}, skipping'.format(i))\n",
    "                continue\n",
    "            edit = random.choice(edits)\n",
    "\n",
    "            # Output row: source, one sampled edit, then all alternatives in random order.\n",
    "            np.random.shuffle(hypos)\n",
    "            f_out.write('{}\\t{}\\t{}\\n'.format(source, edit, '\\t'.join(hypos)))\n",
    "            \n",
    "    print('_'*100)\n",
    "    print('\\n'*3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
